library(SingleCellExperiment)
library(scater)
library(ggplot2)
library(gridExtra)
library(ggthemes)
library(pheatmap)
library(reshape2)
library(ggpubr)
library(viridis)
# Load the saved evaluation objects used throughout this script. Paths are
# relative to the current working directory.
testLargeTrain <- readRDS(file = "results/pancreas_testLargeTrain.rds")
levinPBMC_level1_evaluaRes <- readRDS(
  file = "results/levinPBMC_level1_evaluaRes.rds"
)
levinPBMC_level2_evaluaRes <- readRDS(
  file = "results/levinPBMC_level2_evaluaRes.rds"
)
pancreas_evaluaRes <- readRDS(file = "results/pancreas_evaluaRes.rds")

# ---- EV2A: stacked bars of per-method column means, one facet per data set ----

# Rename "muraro2" to "muraro" in the row names of the ensemble result table.
rownames(pancreas_evaluaRes$scClassify_ensemble) <- gsub(
  "muraro2", "muraro",
  rownames(pancreas_evaluaRes$scClassify_ensemble)
)

# For every method, average the first six columns (as named in the scClassify
# table) over rows; stacking the results gives one row per method.
df_pbmc <- do.call(
  rbind,
  lapply(levinPBMC_level1_evaluaRes, function(res) {
    colMeans(res[, colnames(levinPBMC_level1_evaluaRes$scClassify)[1:6]])
  })
)
df_pbmc2 <- do.call(
  rbind,
  lapply(levinPBMC_level2_evaluaRes, function(res) {
    colMeans(res[, colnames(levinPBMC_level2_evaluaRes$scClassify)[1:6]])
  })
)
# Pancreas rows are split by membership in testLargeTrain:
# rows NOT listed there form the "easy" set, rows listed there the "hard" set.
df_pancreas_easy <- do.call(
  rbind,
  lapply(pancreas_evaluaRes, function(res) {
    colMeans(res[
      !rownames(res) %in% testLargeTrain,
      colnames(pancreas_evaluaRes$scClassify)[1:6]
    ])
  })
)
df_pancreas_hard <- do.call(
  rbind,
  lapply(pancreas_evaluaRes, function(res) {
    colMeans(res[
      rownames(res) %in% testLargeTrain,
      colnames(pancreas_evaluaRes$scClassify)[1:6]
    ])
  })
)

# Reshape each matrix to long format (one row per method/column pair) and tag
# it with the facet label it will be drawn under.
df_pbmc <- setNames(melt(df_pbmc), c("Method", "ErrClass", "value"))
df_pbmc[["Data"]] <- "PBMC (Level 1)"

df_pbmc2 <- setNames(melt(df_pbmc2), c("Method", "ErrClass", "value"))
df_pbmc2[["Data"]] <- "PBMC (Level 2)"

df_pancreas_easy <- setNames(
  melt(df_pancreas_easy),
  c("Method", "ErrClass", "value")
)
df_pancreas_easy[["Data"]] <- "Pancreas (easy)"

df_pancreas_hard <- setNames(
  melt(df_pancreas_hard),
  c("Method", "ErrClass", "value")
)
df_pancreas_hard[["Data"]] <- "Pancreas (hard)"

df <- rbind(df_pancreas_easy, df_pancreas_hard, df_pbmc, df_pbmc2)

# Copy with the ErrClass levels reversed, which flips the stacking/legend order.
df_rev <- df
df_rev[["ErrClass"]] <- factor(
  as.character(df[["ErrClass"]]),
  levels = rev(levels(df[["ErrClass"]]))
)

ggplot(df_rev, aes(x = Method, y = value, fill = ErrClass)) +
  geom_bar(stat = "identity") +
  facet_grid(Data ~ .) +
  # Six colours picked from an 11-step plasma ramp, taken from index 10 down to 1.
  scale_fill_manual(values = plasma(11)[c(10, 8, 6, 4, 2, 1)]) +
  theme_bw() +
  theme(
    legend.position = "top",
    axis.title.y = element_blank(),
    axis.text.x = element_text(angle = 90, hjust = 1)
  )

# No plot is passed, so ggsave() writes the most recently created ggplot.
ggsave(
  "figures/EV2A_barplot_benchmark_plot_all_methods_datasets.pdf",
  width = 8,
  height = 12
)
# Colours for the rank legend: take the 19-colour "Hue Circle" Tableau palette
# in reverse and keep 16 entries in the order 8, 7, 6, 9..19, 1, 2.
ranks_col <- rev(
  tableau_color_pal(palette = "Hue Circle")(19)
)[c(8:6, 9:19, 1:2)]

# Apply the "muraro2" -> "muraro" renaming to the row names of every method's
# pancreas result table.
pancreas_evaluaRes <- lapply(pancreas_evaluaRes, function(res) {
  `rownames<-`(res, gsub("muraro2", "muraro", rownames(res)))
})

# Take column 9 of each method's table, with rows aligned to the row order of
# the first table, and bind the methods side by side (rows = data sets,
# columns = methods). Values are rounded to two decimals.
acc_mat_pancreas <- do.call(
  cbind,
  lapply(pancreas_evaluaRes, function(res) {
    res[rownames(pancreas_evaluaRes[[1]]), 9]
  })
)
acc_mat_pancreas <- round(acc_mat_pancreas, digits = 2)

rownames(acc_mat_pancreas) <- rownames(pancreas_evaluaRes[[1]])

# Rank the methods within each data set; rank 1 goes to the largest value and
# ties are broken by order of appearance. apply() over rows returns the
# transpose, so the result has methods as rows and data sets as columns.
rank_mat_pancreas <- apply(-acc_mat_pancreas, 1, rank, ties.method = "first")

# Mean rank of every method across all pancreas data sets.
# rank_mat_pancreas has methods as rows and data sets as columns.
round(rowMeans(rank_mat_pancreas), 2)
##          scClassify scClassify_ensemble             singleR               moana 
##                2.17                2.53                4.03                7.80 
##       singlecellNet              ACTINN                scVI              CHETAH 
##                5.33                7.00               11.50                5.37 
##                scID             Garnett           scmapCell        scmapCluster 
##                8.83               14.60                8.63                9.77 
##              scPred          Garnett_DE          SVM_reject              CaSTLe 
##               13.43               13.70               11.33                9.97

# Apply the same "muraro2" -> "muraro" renaming to the data set names in
# testLargeTrain so that they can be used to index the columns below.
testLargeTrain <- gsub("muraro2", "muraro", testLargeTrain)

# Mean rank restricted to the data sets listed in testLargeTrain.
# drop = FALSE keeps a matrix even if only one column is selected.
round(rowMeans(rank_mat_pancreas[, testLargeTrain, drop = FALSE]), 2)
##          scClassify scClassify_ensemble             singleR               moana 
##                2.00                1.79                4.36                7.86 
##       singlecellNet              ACTINN                scVI              CHETAH 
##                7.43                7.71               10.57                6.00 
##                scID             Garnett           scmapCell        scmapCluster 
##                6.07               14.93                8.50                9.29 
##              scPred          Garnett_DE          SVM_reject              CaSTLe 
##               11.43               13.57               13.86               10.64

# Mean rank over the remaining data sets (those NOT in testLargeTrain).
# The mask has to be built from the column names: the row names are method
# names, which never match a data set name, so a row-name mask selects every
# column and just repeats the all-data-set means.
round(
  rowMeans(
    rank_mat_pancreas[
      , !colnames(rank_mat_pancreas) %in% testLargeTrain,
      drop = FALSE
    ]
  ),
  2
)
# NOTE(review): the output previously recorded here was identical to the
# all-data-set means printed at the top of this section, an artefact of the
# row-name mask. Re-run this chunk to record the correct values.
# Long-format table for the pancreas ranking dot plot: one row per
# (method, data set) pair holding its rank and its rounded accuracy value.
df_rank_pancreas <- melt(rank_mat_pancreas)
# acc_mat_pancreas is data sets x methods; transposing it first makes melt()
# emit values in the same (method, data set) order as the rank table above.
df_rank_pancreas$Accuracy <- melt(t(acc_mat_pancreas))$value

colnames(df_rank_pancreas) <- c("Method", "Dataset", "Ranking", "Accuracy")


# Make scClassify_ensemble the first factor level of Method.
df_rank_pancreas$Method <- relevel(df_rank_pancreas$Method, "scClassify_ensemble")


# Move the testLargeTrain data sets to the front of the Dataset levels, one at
# a time (so they end up in reverse order of testLargeTrain).
# seq_along() keeps the loop from running when testLargeTrain is empty;
# 1:length(x) would iterate over c(1, 0) and relevel() would fail on NA.
for (i in seq_along(testLargeTrain)) {
  df_rank_pancreas$Dataset <- relevel(df_rank_pancreas$Dataset, testLargeTrain[i])
}


# Dot plot: colour encodes the rank, point size the accuracy value.
ggplot(df_rank_pancreas, aes(x = Dataset, y = Method, color = as.factor(Ranking), size = Accuracy)) +
  geom_point() +
  scale_color_manual(values = ranks_col) +
  theme_classic() +
  # scale_size(breaks = c(0.2, seq(0.5, 0.9, 0.1))) +
  theme(axis.text.x = element_text(angle = 90), text = element_text(size = 14)) +
  xlab("Dataset") + ylab("Method")

# No plot is passed, so ggsave() writes the most recently created ggplot.
ggsave("figures/Figure2C_pancreas_pairwise_ranking.pdf", width = 12, height = 8, useDingbats = FALSE, family = "sans")

# ---- PBMC level 2: per-data-set method ranking ----

# Element 9 of each method's result table, bound side by side
# (rows = data sets, columns = methods) and rounded to two decimals.
acc_mat_pbmc_level2 <- do.call(
  cbind,
  lapply(levinPBMC_level2_evaluaRes, function(res) res[[9]])
)
acc_mat_pbmc_level2 <- round(acc_mat_pbmc_level2, digits = 2)

rownames(acc_mat_pbmc_level2) <- rownames(levinPBMC_level2_evaluaRes[[1]])

# Rank the methods within each data set (rank 1 = largest value, ties broken
# by order of appearance); the result has methods as rows.
rank_mat_pbmc_level2 <- apply(
  -acc_mat_pbmc_level2, 1, rank,
  ties.method = "first"
)

rowMeans(rank_mat_pbmc_level2)
##          scClassify scClassify_ensemble             singleR               moana 
##            3.000000            3.000000            5.809524            8.952381 
##       singlecellNet              ACTINN                scVI              CHETAH 
##            4.380952            3.952381            5.833333            9.333333 
##                scID             Garnett           scmapCell        scmapCluster 
##            8.666667           13.452381           11.571429           12.976190 
##              scPred          Garnett_DE          SVM_reject              CaSTLe 
##           15.404762           13.428571            9.690476            6.547619

# Long format: one row per (method, data set) pair with rank and accuracy.
# The accuracy matrix is transposed so both melt() calls share the same order.
df_rank_pbmc_level2 <- melt(rank_mat_pbmc_level2)
df_rank_pbmc_level2[["Accuracy"]] <- melt(t(acc_mat_pbmc_level2))[["value"]]

colnames(df_rank_pbmc_level2) <- c("Method", "Dataset", "Ranking", "Accuracy")
# Make scClassify_ensemble the first factor level of Method.
df_rank_pbmc_level2[["Method"]] <- relevel(
  df_rank_pbmc_level2[["Method"]],
  ref = "scClassify_ensemble"
)

# Dot plot: colour encodes the rank, point size the accuracy value.
ggplot(
  df_rank_pbmc_level2,
  aes(x = Dataset, y = Method, color = as.factor(Ranking), size = Accuracy)
) +
  geom_point() +
  scale_color_manual(values = ranks_col) +
  theme_classic() +
  theme(
    axis.text.x = element_text(angle = 90),
    text = element_text(size = 14)
  ) +
  labs(x = "Dataset", y = "Method")

# No plot is passed, so ggsave() writes the most recently created ggplot.
ggsave(
  "figures/Figure2C_pbmc_level2_pairwise_ranking.pdf",
  width = 12,
  height = 8,
  useDingbats = FALSE,
  family = "sans"
)



# ---- PBMC level 1: per-data-set method ranking ----

# Element 9 of each method's result table, bound side by side
# (rows = data sets, columns = methods) and rounded to two decimals.
acc_mat_pbmc_level1 <- round(do.call(cbind, lapply(levinPBMC_level1_evaluaRes, "[[", 9)), 2)

rownames(acc_mat_pbmc_level1) <- rownames(levinPBMC_level1_evaluaRes[[1]])

# Rank the methods within each data set (rank 1 = largest value, ties broken
# by order of appearance); the result has methods as rows.
rank_mat_pbmc_level1 <- apply(acc_mat_pbmc_level1, 1, function(x) rank(-x, ties.method = "first"))

rowMeans(rank_mat_pbmc_level1)
##          scClassify scClassify_ensemble             singleR               moana 
##            2.952381            3.309524            6.071429            5.523810 
##       singlecellNet              ACTINN                scVI              CHETAH 
##            5.380952            4.500000            6.119048            8.333333 
##                scID             Garnett           scmapCell        scmapCluster 
##           11.142857           11.476190           11.523810           14.142857 
##              scPred          Garnett_DE          SVM_reject              CaSTLe 
##           15.023810           14.571429            8.619048            7.309524

# Long format: one row per (method, data set) pair with rank and accuracy.
# The accuracy matrix is transposed so both melt() calls share the same order.
df_rank_pbmc_level1 <- melt(rank_mat_pbmc_level1)
df_rank_pbmc_level1$Accuracy <- melt(t(acc_mat_pbmc_level1))$value

colnames(df_rank_pbmc_level1) <- c("Method", "Dataset", "Ranking", "Accuracy")
# Relevel this table's own Method column. Taking the column from the level 2
# table would only give the right labels if both tables happened to have
# identical row counts and method ordering.
df_rank_pbmc_level1$Method <- relevel(df_rank_pbmc_level1$Method, "scClassify_ensemble")

# Dot plot: colour encodes the rank, point size the accuracy value.
ggplot(df_rank_pbmc_level1, aes(x = Dataset, y = Method, color = as.factor(Ranking), size = Accuracy)) +
  geom_point() +
  scale_color_manual(values = ranks_col) +
  theme_classic() +
  theme(axis.text.x = element_text(angle = 90), text = element_text(size = 14)) +
  xlab("Dataset") + ylab("Method")

# No plot is passed, so ggsave() writes the most recently created ggplot.
ggsave("figures/Figure2C_pbmc_level1_pairwise_ranking.pdf", width = 12, height = 8, useDingbats = FALSE, family = "sans")
# ---- Ensemble vs. single scClassify: accuracy difference and mean ----

# Lower-case the row names of the scClassify and scClassify_ensemble tables so
# the ensemble rows can be looked up by the scClassify row names below.
rownames(levinPBMC_level1_evaluaRes[["scClassify"]]) <- tolower(
  rownames(levinPBMC_level1_evaluaRes[["scClassify"]])
)
rownames(levinPBMC_level1_evaluaRes[["scClassify_ensemble"]]) <- tolower(
  rownames(levinPBMC_level1_evaluaRes[["scClassify_ensemble"]])
)

rownames(levinPBMC_level2_evaluaRes[["scClassify"]]) <- tolower(
  rownames(levinPBMC_level2_evaluaRes[["scClassify"]])
)
rownames(levinPBMC_level2_evaluaRes[["scClassify_ensemble"]]) <- tolower(
  rownames(levinPBMC_level2_evaluaRes[["scClassify_ensemble"]])
)

# For the pancreas scClassify table every "2" is additionally stripped from
# the lower-cased row names; the ensemble table is only lower-cased.
rownames(pancreas_evaluaRes[["scClassify"]]) <- gsub(
  "2", "",
  tolower(rownames(pancreas_evaluaRes[["scClassify"]]))
)
rownames(pancreas_evaluaRes[["scClassify_ensemble"]]) <- tolower(
  rownames(pancreas_evaluaRes[["scClassify_ensemble"]])
)

# For each collection: ensemble accuracy (rows re-ordered to match the
# scClassify table) minus scClassify accuracy, and the mean of the two.
# Both vectors are named by the scClassify row names.

# PBMC level 1
levinPBMC_acc_mean <- rowMeans(cbind(
  levinPBMC_level1_evaluaRes$scClassify_ensemble[
    rownames(levinPBMC_level1_evaluaRes$scClassify),
  ]$Accuracy,
  levinPBMC_level1_evaluaRes$scClassify$Accuracy
))
levinPBMC_acc_diff <-
  levinPBMC_level1_evaluaRes$scClassify_ensemble[
    rownames(levinPBMC_level1_evaluaRes$scClassify),
  ]$Accuracy -
  levinPBMC_level1_evaluaRes$scClassify$Accuracy
names(levinPBMC_acc_mean) <- rownames(levinPBMC_level1_evaluaRes$scClassify)
names(levinPBMC_acc_diff) <- names(levinPBMC_acc_mean)

# PBMC level 2
levinPBMC_level2_acc_mean <- rowMeans(cbind(
  levinPBMC_level2_evaluaRes$scClassify_ensemble[
    rownames(levinPBMC_level2_evaluaRes$scClassify),
  ]$Accuracy,
  levinPBMC_level2_evaluaRes$scClassify$Accuracy
))
levinPBMC_level2_acc_diff <-
  levinPBMC_level2_evaluaRes$scClassify_ensemble[
    rownames(levinPBMC_level2_evaluaRes$scClassify),
  ]$Accuracy -
  levinPBMC_level2_evaluaRes$scClassify$Accuracy
names(levinPBMC_level2_acc_mean) <- rownames(levinPBMC_level2_evaluaRes$scClassify)
names(levinPBMC_level2_acc_diff) <- names(levinPBMC_level2_acc_mean)

# Pancreas
pancreas_acc_mean <- rowMeans(cbind(
  pancreas_evaluaRes$scClassify_ensemble[
    rownames(pancreas_evaluaRes$scClassify),
  ]$Accuracy,
  pancreas_evaluaRes$scClassify$Accuracy
))
pancreas_acc_diff <-
  pancreas_evaluaRes$scClassify_ensemble[
    rownames(pancreas_evaluaRes$scClassify),
  ]$Accuracy -
  pancreas_evaluaRes$scClassify$Accuracy
names(pancreas_acc_mean) <- rownames(pancreas_evaluaRes$scClassify)
names(pancreas_acc_diff) <- names(pancreas_acc_mean)

# One data frame per collection, labelled, then stacked for testing/plotting.
df_pancreas <- data.frame(diff = pancreas_acc_diff, mean = pancreas_acc_mean)
df_pancreas[["dataset"]] <- "Pancreas"

df_pbmc_level2 <- data.frame(
  diff = levinPBMC_level2_acc_diff,
  mean = levinPBMC_level2_acc_mean
)
df_pbmc_level2[["dataset"]] <- "PBMC (level 2)"

df_pbmc_level1 <- data.frame(diff = levinPBMC_acc_diff, mean = levinPBMC_acc_mean)
df_pbmc_level1[["dataset"]] <- "PBMC (level 1)"


df <- rbind(df_pancreas, df_pbmc_level1, df_pbmc_level2)

# One-sample, one-sided Wilcoxon signed rank test on the accuracy differences
# (alternative: location greater than 0).
wilcox.test(df$diff, alternative = "greater")
## 
##  Wilcoxon signed rank test with continuity correction
## 
## data:  df$diff
## V = 5466.5, p-value = 1.17e-11
## alternative hypothesis: true location is greater than 0

# Fraction of rows with a strictly positive difference.
sum(df$diff > 0) / nrow(df)
## [1] 0.8070175

# Scatter of accuracy difference against mean accuracy, coloured by data
# collection, with a horizontal reference line at zero drawn underneath.
g_acc <- ggplot(df, aes(x = mean, y = diff, color = dataset)) +
  geom_hline(yintercept = 0, color = "brown", size = 1, alpha = 0.8) +
  geom_point(alpha = 0.8, size = 2) +
  scale_color_brewer(palette = "Set1") +
  labs(x = "Mean Accuracy Rate", y = "Accuracy Rate Difference") +
  theme_bw() +
  theme(aspect.ratio = 1, text = element_text(size = 14))

g_acc

# No plot is passed, so ggsave() writes the most recently created ggplot.
ggsave(
  "figures/Figure1D_ensemble_accuracy_diff.pdf",
  width = 8,
  height = 6
)